From: Jakub Kuderski Date: Mon, 9 Jan 2023 16:35:46 +0000 (-0500) Subject: [PATCH] [mlir][spirv] Account for type conversion failures in scf-to-spirv X-Git-Tag: archive/raspbian/1%15.0.7-10+rpi1~1^2~3 X-Git-Url: https://dgit.raspbian.org/%22http://www.example.com/cgi/%22/%22http:/www.example.com/cgi/%22?a=commitdiff_plain;h=19144774c70e9967594ccbde516692e8427c2003;p=llvm-toolchain-15.git [PATCH] [mlir][spirv] Account for type conversion failures in scf-to-spirv Fixes: https://github.com/llvm/llvm-project/issues/59136 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D141292 Gbp-Pq: Name CVE-2023-29934.patch --- diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index 10623b63d9..81c521d9da 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; @@ -286,6 +287,10 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, SmallVector returnTypes; for (auto result : ifOp.getResults()) { auto convertedType = typeConverter.convertType(result.getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + loc, llvm::formatv("failed to convert type '{0}'", result.getType())); + returnTypes.push_back(convertedType); } replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, diff --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir index f937ac6c4e..79f53aaa8f 100644 --- a/mlir/test/Conversion/SCFToSPIRV/if.mlir +++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir @@ -153,4 +153,18 @@ func.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10 return } +// Memrefs without a spirv storage class are not supported. The conversion +// should preserve the `scf.if` and not crash. +func.func @unsupported_yield_type(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %c : i1) { +// CHECK-LABEL: @unsupported_yield_type +// CHECK-NEXT: scf.if +// CHECK: spirv.Return + %r = scf.if %c -> (memref<8xi32>) { + scf.yield %arg0 : memref<8xi32> + } else { + scf.yield %arg1 : memref<8xi32> + } + return +} + } // end module